---
title: Cross-validation
date: 2019-01-28T13:30:00-06:00  # Schedule page publish date.
draft: false
type: docs
bibliography: [../../static/bib/sources.bib]
csl: [../../static/bib/apa.csl]
link-citations: true
menu:
  notes:
    parent: Resampling methods
    weight: 1
---
library(tidyverse)
library(tidymodels)
library(magrittr)
library(here)
library(rcfss)

set.seed(1234)
theme_set(theme_minimal())

Resampling methods are essential to test and evaluate statistical models. Because you likely do not have the resources or capabilities to repeatedly sample from your population of interest, you can instead repeatedly draw from your original sample to obtain additional information about your model. For instance, you could repeatedly draw samples from your data, estimate a linear regression model on each sample, and then examine how the estimated model differs across samples. This allows you to assess the variability and stability of your model in a way not possible if you can only fit the model once.
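The idea in the paragraph above can be sketched in a few lines of base R. This is only an illustration under stated assumptions: the built-in mtcars data stands in for "your original sample", the rows are drawn with replacement, and the names n_resamples and slopes are hypothetical.

```r
# Minimal sketch of resampling with base R only.
# Assumption: the built-in mtcars data stands in for the original sample.
set.seed(1234)

n_resamples <- 200  # hypothetical number of repeated draws

# For each resample: draw rows with replacement, refit the model, keep the slope
slopes <- replicate(n_resamples, {
  rows <- sample(nrow(mtcars), replace = TRUE)
  coef(lm(mpg ~ hp, data = mtcars[rows, ]))[["hp"]]
})

# Spread of the slope across resamples: a rough gauge of its stability
summary(slopes)
sd(slopes)
```

A small standard deviation relative to the size of the slope suggests the estimated relationship is stable across resamples.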

There are two major types of resampling methods we will consider:

1 Training/test set split

In most modeling situations, we can immediately partition the dataset into a training set and a test set. The training set will be used for model construction, and the test set will be used to evaluate the performance of the final model. This is most important – while you can reuse the training set many times to build different statistical models, you can only use the test set of data once. If you reuse it, you introduce data leakage into your modeling process and no longer have unbiased estimates of the test error. This is why collaborative platforms such as Kaggle hold back a portion of the dataset in their competitions. You can use the training set to build the strongest performing model, but you cannot tune your model based on the test error because you do not have access to it.

2 Validation set

Even accounting for the training/test set split, one issue with using the same data to both fit and evaluate our model is that we will bias our model towards fitting the data that we have. We may fit our function to create the results we expect or desire, rather than the “true” function. Instead, we can further split our training set into distinct training and validation sets. The training set can be used repeatedly to train different models. We then use the validation set to evaluate the model’s performance, generating metrics such as the mean squared error (MSE) or the error rate. Unlike the test set, we are permitted to use the validation set multiple times. The important thing is that we do not use the validation set to train or fit the model, only evaluate its performance after it has been fit.
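One way to picture the three partitions is the following base R sketch. The mtcars data and the 60/20/20 proportions are assumptions chosen for illustration, not a recommendation from these notes.

```r
# Minimal sketch of a training/validation/test partition in base R.
# Assumptions: mtcars as stand-in data; the 60/20/20 split is arbitrary.
set.seed(1234)

n <- nrow(mtcars)
idx <- sample(n)              # random permutation of the row numbers

n_train <- floor(0.6 * n)     # rows used to fit candidate models
n_valid <- floor(0.2 * n)     # rows used to compare candidate models

train_set <- mtcars[idx[1:n_train], ]
valid_set <- mtcars[idx[(n_train + 1):(n_train + n_valid)], ]
test_set  <- mtcars[idx[(n_train + n_valid + 1):n], ]   # touched only once, at the end

c(train = nrow(train_set), validation = nrow(valid_set), test = nrow(test_set))
```

Because the partitions come from a single permutation, every observation lands in exactly one of them.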

2.1 Regression

Here we will examine the relationship between horsepower and car mileage in the Auto dataset (found in library(ISLR)):

library(ISLR)

Auto <- as_tibble(Auto)
Auto
## # A tibble: 392 x 9
##      mpg cylinders displacement horsepower weight acceleration  year origin
##    <dbl>     <dbl>        <dbl>      <dbl>  <dbl>        <dbl> <dbl>  <dbl>
##  1    18         8          307        130   3504         12      70      1
##  2    15         8          350        165   3693         11.5    70      1
##  3    18         8          318        150   3436         11      70      1
##  4    16         8          304        150   3433         12      70      1
##  5    17         8          302        140   3449         10.5    70      1
##  6    15         8          429        198   4341         10      70      1
##  7    14         8          454        220   4354          9      70      1
##  8    14         8          440        215   4312          8.5    70      1
##  9    14         8          455        225   4425         10      70      1
## 10    15         8          390        190   3850          8.5    70      1
## # … with 382 more rows, and 1 more variable: name <fct>
ggplot(Auto, aes(horsepower, mpg)) +
  geom_point()

The relationship does not appear to be strictly linear:

ggplot(Auto, aes(horsepower, mpg)) +
  geom_point() +
  geom_smooth(method = "lm", se = FALSE)

Perhaps by adding quadratic terms to the linear regression we could improve overall model fit. To evaluate the model, we will split the data into a training set and validation set, estimate a series of higher-order models, and calculate a summary statistic for the accuracy of the predicted mpg. To calculate the accuracy of the model, we will use mean squared error (MSE), defined as

\[MSE = \frac{1}{N} \sum_{i = 1}^{N}{(y_i - \hat{f}(x_i))^2}\]
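As a quick worked example of the formula, here it is applied to four made-up observations. The values of y and y_hat are invented for illustration and do not come from the Auto data.

```r
# Worked example of the MSE formula with made-up numbers (not from Auto)
y     <- c(18, 15, 18, 16)   # observed values y_i
y_hat <- c(17, 16, 20, 16)   # model predictions for each x_i

squared_errors <- (y - y_hat)^2     # 1, 1, 4, 0
mse_by_hand <- mean(squared_errors) # (1 + 1 + 4 + 0) / 4
mse_by_hand
## [1] 1.5
```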

For this task, first we use rsample::initial_split() to create training and validation sets (using a 50/50 split), then estimate a linear regression model without any quadratic terms.

set.seed(1234)

auto_split <- initial_split(data = Auto, prop = 0.5)
auto_train <- training(auto_split)
auto_test <- testing(auto_split)
auto_lm <- glm(mpg ~ horsepower, data = auto_train)
summary(auto_lm)
## 
## Call:
## glm(formula = mpg ~ horsepower, data = auto_train)
## 
## Deviance Residuals: 
##      Min        1Q    Median        3Q       Max  
## -13.7105   -3.4442   -0.5342    2.6256   15.1015  
## 
## Coefficients:
##              Estimate Std. Error t value Pr(>|t|)    
## (Intercept) 40.057910   1.054798   37.98   <2e-16 ***
## horsepower  -0.157604   0.009402  -16.76   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for gaussian family taken to be 24.80151)
## 
##     Null deviance: 11780.6  on 195  degrees of freedom
## Residual deviance:  4811.5  on 194  degrees of freedom
## AIC: 1189.6
## 
## Number of Fisher Scoring iterations: 2

To estimate the MSE for a single partition (i.e. for a training or validation set):

  1. Use broom::augment() to generate predicted values for the data set
  2. Calculate the residuals and square each value
  3. Calculate the mean of all the squared residuals in the data set

For the training set, this would look like:

(train_mse <- augment(auto_lm, newdata = auto_train) %>%
  mse(truth = mpg, estimate = .fitted))
## # A tibble: 1 x 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 mse     standard        24.5

Note the special use of the %$% (exposition) pipe operator from the magrittr package in the poly_mse() function below. This allows us to directly access columns from the data frame entering the pipe. This is especially useful for integrating non-tidy functions into a tidy operation.

For the validation set:

(test_mse <- augment(auto_lm, newdata = auto_test) %>%
  mse(truth = mpg, estimate = .fitted))
## # A tibble: 1 x 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 mse     standard        23.4

For a strictly linear model, the MSE for the validation set is 23.38. How does this compare to a quadratic model? We can use the poly() function in conjunction with a map() iteration to estimate the MSE for a series of models with higher-order polynomial terms:

# visualize each model
ggplot(Auto, aes(horsepower, mpg)) +
  geom_point(alpha = .1) +
  geom_smooth(aes(color = "1"),
              method = "glm",
              formula = y ~ poly(x, degree = 1),
              se = FALSE) +
  geom_smooth(aes(color = "2"),
              method = "glm",
              formula = y ~ poly(x, degree = 2),
              se = FALSE) +
  geom_smooth(aes(color = "3"),
              method = "glm",
              formula = y ~ poly(x, degree = 3),
              se = FALSE) +
  geom_smooth(aes(color = "4"),
              method = "glm",
              formula = y ~ poly(x, degree = 4),
              se = FALSE) +
  geom_smooth(aes(color = "5"),
              method = "glm",
              formula = y ~ poly(x, degree = 5),
              se = FALSE) +
  scale_color_brewer(type = "qual", palette = "Dark2") +
  labs(x = "Horsepower",
       y = "MPG",
       color = "Highest-order\npolynomial")

# function to estimate model using training set and generate fit statistics
# using the test set
poly_results <- function(train, test, i) {
  # Fit the model to the training set
  mod <- glm(mpg ~ poly(horsepower, i, raw = TRUE), data = train)
  
  # `augment` will save the predictions with the test data set
  res <- augment(mod, newdata = test) %>%
    mse(truth = mpg, estimate = .fitted)
  
  # Return the one-row tibble holding the MSE computed on the test data
  res
}

# function to return MSE for a specific higher-order polynomial term
poly_mse <- function(i, train, test){
  poly_results(train, test, i) %$%
    mean(.estimate)
}

cv_mse <- tibble(terms = seq(from = 1, to = 5),
                 mse_test = map_dbl(terms, poly_mse, auto_train, auto_test))

ggplot(cv_mse, aes(terms, mse_test)) +
  geom_line() +
  labs(title = "Comparing quadratic linear models",
       subtitle = "Using validation set",
       x = "Highest-order polynomial",
       y = "Mean Squared Error")

Based on the MSE for the validation set, a polynomial model with a quadratic term (\(\text{horsepower}^2\)) produces a lower average error than the standard model. A higher-order term such as a fifth-order polynomial leads to an even larger reduction, though it increases the complexity of interpreting the model.

2.2 Classification

Recall our efforts to predict passenger survival during the sinking of the Titanic.

library(titanic)
titanic <- as_tibble(titanic_train) %>%
  mutate(Survived = factor(Survived))

titanic %>%
  head() %>%
  knitr::kable()

| PassengerId | Survived | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 0 | 3 | Braund, Mr. Owen Harris | male | 22 | 1 | 0 | A/5 21171 | 7.2500 | | S |
| 2 | 1 | 1 | Cumings, Mrs. John Bradley (Florence Briggs Thayer) | female | 38 | 1 | 0 | PC 17599 | 71.2833 | C85 | C |
| 3 | 1 | 3 | Heikkinen, Miss. Laina | female | 26 | 0 | 0 | STON/O2. 3101282 | 7.9250 | | S |
| 4 | 1 | 1 | Futrelle, Mrs. Jacques Heath (Lily May Peel) | female | 35 | 1 | 0 | 113803 | 53.1000 | C123 | S |
| 5 | 0 | 3 | Allen, Mr. William Henry | male | 35 | 0 | 0 | 373450 | 8.0500 | | S |
| 6 | 0 | 3 | Moran, Mr. James | male | NA | 0 | 0 | 330877 | 8.4583 | | Q |

survive_age_woman_x <- glm(Survived ~ Age * Sex, data = titanic,
                           family = binomial)
summary(survive_age_woman_x)
## 
## Call:
## glm(formula = Survived ~ Age * Sex, family = binomial, data = titanic)
## 
## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -1.9401  -0.7136  -0.5883   0.7626   2.2455  
## 
## Coefficients:
##             Estimate Std. Error z value Pr(>|z|)   
## (Intercept)  0.59380    0.31032   1.913  0.05569 . 
## Age          0.01970    0.01057   1.863  0.06240 . 
## Sexmale     -1.31775    0.40842  -3.226  0.00125 **
## Age:Sexmale -0.04112    0.01355  -3.034  0.00241 **
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 964.52  on 713  degrees of freedom
## Residual deviance: 740.40  on 710  degrees of freedom
##   (177 observations deleted due to missingness)
## AIC: 748.4
## 
## Number of Fisher Scoring iterations: 4

We can use the same validation set approach to evaluate the model’s accuracy. For classification models, instead of using MSE we examine the error rate. That is, of all the predictions generated for the test set, what percentage of predictions are incorrect? The goal is to minimize this value as much as possible (ideally, until we make no errors and our error rate is \(0\)).
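As a minimal illustration of the error rate, consider eight made-up predictions. The labels, the probabilities, and the 0.5 cutoff are all assumptions for this sketch.

```r
# Error rate by hand, using made-up labels and predicted probabilities
truth <- factor(c(0, 1, 1, 0, 1, 0, 0, 1), levels = c(0, 1))
prob  <- c(0.2, 0.7, 0.4, 0.1, 0.9, 0.6, 0.3, 0.8)  # hypothetical model output

# Assumption: classify as 1 when the predicted probability is at least 0.5
pred <- factor(as.integer(prob >= 0.5), levels = c(0, 1))

error_rate <- mean(pred != truth)  # share of incorrect predictions (2 of 8)
accuracy   <- 1 - error_rate       # the complement of the error rate
error_rate
## [1] 0.25
```

The error rate and the accuracy always sum to one, so reporting either one determines the other.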

# function to convert log-odds to probabilities
logit2prob <- function(x){
  exp(x) / (1 + exp(x))
}
# split the data into training and validation sets
titanic_split <- initial_split(data = titanic, prop = 0.5)

# fit model to training data
train_model <- glm(Survived ~ Age * Sex, data = training(titanic_split),
                   family = binomial)
summary(train_model)
## 
## Call:
## glm(formula = Survived ~ Age * Sex, family = binomial, data = training(titanic_split))
## 
## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -2.1511  -0.7346  -0.5386   0.7339   2.2216  
## 
## Coefficients:
##             Estimate Std. Error z value Pr(>|z|)    
## (Intercept)  0.17464    0.41877   0.417 0.676659    
## Age          0.03570    0.01525   2.342 0.019198 *  
## Sexmale     -0.59608    0.56604  -1.053 0.292313    
## Age:Sexmale -0.06833    0.01994  -3.426 0.000612 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 478.37  on 353  degrees of freedom
## Residual deviance: 361.88  on 350  degrees of freedom
##   (92 observations deleted due to missingness)
## AIC: 369.88
## 
## Number of Fisher Scoring iterations: 4
# calculate predictions using validation set
x_test_accuracy <- augment(train_model, newdata = testing(titanic_split)) %>% 
  as_tibble() %>%
  mutate(.prob = logit2prob(.fitted),
         .pred = factor(round(.prob)))

# calculate accuracy on the validation set (error rate = 1 - accuracy)
accuracy(x_test_accuracy, truth = Survived, estimate = .pred)
## # A tibble: 1 x 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.783

This interaction model has an accuracy of 78.3%, which corresponds to an error rate of 21.7%. We could compare this error rate to alternative classification models, either other logistic regression models (using different formulas) or a tree-based method.

2.3 Drawbacks to validation sets

There are two main problems with validation sets:

  1. Validation estimates of the test error rates can be highly variable depending on which observations are sampled into the training and test sets. See what happens if we repeat the sampling, estimation, and validation procedure for the Auto data set:

    mse_variable <- function(Auto){
      auto_split <- initial_split(Auto, prop = 0.5)
      auto_train <- training(auto_split)
      auto_test <- testing(auto_split)
    
      cv_mse <- tibble(terms = seq(from = 1, to = 5),
                           mse_test = map_dbl(terms, poly_mse, auto_train, auto_test))
    
      return(cv_mse)
    }
    
    rerun(10, mse_variable(Auto)) %>%
      bind_rows(.id = "id") %>%
      ggplot(aes(terms, mse_test, color = id)) +
      geom_line() +
      labs(title = "Variability of MSE estimates",
           subtitle = "Using the validation set approach",
           x = "Degree of Polynomial",
           y = "Mean Squared Error") +
      theme(legend.position = "none")

    Depending on the specific training/test split, our MSE varies by up to 5.

  2. If you don’t have a large data set, you’ll have to dramatically shrink the size of your training set. Most statistical learning methods perform better with more observations - if you don’t have enough data in the training set, you might overestimate the error rate in the test set.

3 Leave-one-out cross-validation

An alternative method is leave-one-out cross-validation (LOOCV). As with the validation set approach, you split the data into two parts. The difference is that you hold out only one observation for the test set and keep all remaining observations in the training set. The statistical learning method is fit on the \(N-1\) training observations. You then use the held-out observation to calculate \(MSE_1 = (y_1 - \hat{y}_1)^2\), which is an approximately unbiased estimate of the test error. Because this MSE is highly dependent on which observation is held out, we repeat this process for every single observation in the data set. Mathematically, this looks like:

\[CV_{(N)} = \frac{1}{N} \sum_{i = 1}^{N}{MSE_i}\]

This method produces estimates of the error rate that are approximately unbiased and do not vary for a given dataset, unlike the validation set approach where the MSE estimate is highly dependent on the sampling process for training/test sets. However, it can have high variance because the \(N\) “training sets” are so similar to one another. LOOCV is also highly flexible and works with any kind of predictive modeling.

Of course, the downside is that this method is computationally expensive. You have to estimate \(N\) different models - if you have a large \(N\) or each individual model takes a long time to compute, you may be stuck waiting a long time for the computer to finish its calculations.
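To make the procedure concrete, here is a sketch of LOOCV written as a plain loop in base R. It assumes the built-in mtcars data and a simple mpg ~ hp model as stand-ins; the names mse_i and cv_n are hypothetical and simply mirror the notation in the formula above.

```r
# Sketch of LOOCV as an explicit loop (base R only).
# Assumption: mtcars and mpg ~ hp stand in for the data and model of interest.
n <- nrow(mtcars)
mse_i <- numeric(n)

for (i in seq_len(n)) {
  # Fit the model on the N - 1 observations that remain after dropping row i
  fit <- lm(mpg ~ hp, data = mtcars[-i, ])

  # Predict the single held-out observation and store its squared error
  pred <- predict(fit, newdata = mtcars[i, , drop = FALSE])
  mse_i[i] <- (mtcars$mpg[i] - pred)^2
}

# CV_(N): the average of the N held-out squared errors
cv_n <- mean(mse_i)
cv_n
```

The loop fits one model per observation, which is where the computational cost comes from.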

3.1 LOOCV in linear regression

We can use the loo_cv() function in the rsample library to create the LOOCV splits for a data frame, and then fit any linear or logistic regression model to each split. In the call below it takes a single argument: the data frame being cross-validated. For the Auto dataset, this looks like:

loocv_data <- loo_cv(Auto)
loocv_data
## # Leave-one-out cross-validation 
## # A tibble: 392 x 2
##    splits          id        
##    <list>          <chr>     
##  1 <split [391/1]> Resample1 
##  2 <split [391/1]> Resample2 
##  3 <split [391/1]> Resample3 
##  4 <split [391/1]> Resample4 
##  5 <split [391/1]> Resample5 
##  6 <split [391/1]> Resample6 
##  7 <split [391/1]> Resample7 
##  8 <split [391/1]> Resample8 
##  9 <split [391/1]> Resample9 
## 10 <split [391/1]> Resample10
## # … with 382 more rows

Each element of loocv_data$splits is an object of class rsplit. This is essentially an efficient container for storing both the analysis data (i.e. the training data set) and the assessment data (i.e. the validation data set). If we print the contents of a single rsplit object:

first_resample <- loocv_data$splits[[1]]
first_resample
## <391/1/392>

This tells us there are 391 observations in the analysis set, 1 observation in the assessment set, and the original data set contained 392 observations. To extract the analysis/assessment sets, use analysis() or assessment() respectively:

analysis(first_resample)
## # A tibble: 391 x 9
##      mpg cylinders displacement horsepower weight acceleration  year origin
##    <dbl>     <dbl>        <dbl>      <dbl>  <dbl>        <dbl> <dbl>  <dbl>
##  1    18         8          307        130   3504         12      70      1
##  2    15         8          350        165   3693         11.5    70      1
##  3    18         8          318        150   3436         11      70      1
##  4    16         8          304        150   3433         12      70      1
##  5    17         8          302        140   3449         10.5    70      1
##  6    15         8          429        198   4341         10      70      1
##  7    14         8          454        220   4354          9      70      1
##  8    14         8          440        215   4312          8.5    70      1
##  9    14         8          455        225   4425         10      70      1
## 10    15         8          390        190   3850          8.5    70      1
## # … with 381 more rows, and 1 more variable: name <fct>
assessment(first_resample)
## # A tibble: 1 x 9
##     mpg cylinders displacement horsepower weight acceleration  year origin
##   <dbl>     <dbl>        <dbl>      <dbl>  <dbl>        <dbl> <dbl>  <dbl>
## 1    25         4          113         95   2228           14    71      3
## # … with 1 more variable: name <fct>

Given this new loocv_data data frame, we write a function that will, for each resample:

  1. Obtain the analysis data set (i.e. the \(N-1\) training set)
  2. Fit a linear regression model
  3. Predict the test data (also known as the assessment data, the \(1\) test set) using the broom package
  4. Determine the MSE for each sample
holdout_results <- function(splits) {
  # Fit the model to the N-1
  mod <- glm(mpg ~ horsepower, data = analysis(splits))
  
  # Save the heldout observation
  holdout <- assessment(splits)
  
  # `augment` will save the predictions with the holdout data set
  res <- augment(mod, newdata = holdout) %>%
    # calculate the metric
    mse(truth = mpg, estimate = .fitted)
  
  # Return the metrics
  res
}

This function works for a single resample:

holdout_results(loocv_data$splits[[1]])
## # A tibble: 1 x 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 mse     standard     0.00355

To compute the MSE for each heldout observation (i.e. estimate the test MSE for each of the \(N\) observations), we use the map() function from the purrr package to estimate the model for each training set, then calculate the MSE for the observation in each test set:

loocv_data_poly1 <- loocv_data %>%
  mutate(results = map(splits, holdout_results)) %>%
  unnest(results) %>%
  spread(.metric, .estimate)
loocv_data_poly1
## # A tibble: 392 x 4
##    splits          id         .estimator      mse
##    <list>          <chr>      <chr>         <dbl>
##  1 <split [391/1]> Resample1  standard    0.00355
##  2 <split [391/1]> Resample2  standard    1.25   
##  3 <split [391/1]> Resample3  standard   19.6    
##  4 <split [391/1]> Resample4  standard    2.42   
##  5 <split [391/1]> Resample5  standard   16.7    
##  6 <split [391/1]> Resample6  standard   97.0    
##  7 <split [391/1]> Resample7  standard   57.7    
##  8 <split [391/1]> Resample8  standard    1.77   
##  9 <split [391/1]> Resample9  standard   15.3    
## 10 <split [391/1]> Resample10 standard   24.2    
## # … with 382 more rows

Now we can compute the overall LOOCV MSE for the data set by calculating the mean of the mse column:

loocv_data_poly1 %>%
  summarize(mse = mean(mse))
## # A tibble: 1 x 1
##     mse
##   <dbl>
## 1  24.2

We can also use this method to compare the optimal number of polynomial terms as before.

# modified function to estimate model with varying highest order polynomial
holdout_results <- function(splits, i) {
  # Fit the model to the N-1
  mod <- glm(mpg ~ poly(horsepower, i), data = analysis(splits))
  
  # `augment` will save the predictions with the holdout data set
  res <- augment(mod, newdata = assessment(splits)) %>%
    # calculate the metric
    mse(truth = mpg, estimate = .fitted)
  
  # Return the one-row tibble holding the held-out MSE
  res
}

# function to return MSE for a specific higher-order polynomial term
poly_mse <- function(i, loocv_data){
  loocv_mod <- loocv_data %>%
    mutate(results = map(splits, holdout_results, i)) %>%
    unnest(results) %>%
    spread(.metric, .estimate)
  
  mean(loocv_mod$mse)
}

cv_mse <- tibble(terms = seq(from = 1, to = 5),
                 mse_loocv = map_dbl(terms, poly_mse, loocv_data))
cv_mse
## # A tibble: 5 x 2
##   terms mse_loocv
##   <int>     <dbl>
## 1     1      24.2
## 2     2      19.2
## 3     3      19.3
## 4     4      19.4
## 5     5      19.0
ggplot(cv_mse, aes(terms, mse_loocv)) +
  geom_line() +
  labs(title = "Comparing quadratic linear models",
       subtitle = "Using LOOCV",
       x = "Highest-order polynomial",
       y = "Mean Squared Error")

We arrive at a similar conclusion. There may be a very marginal advantage to adding a fifth-order polynomial, but not substantial enough for the additional complexity over a mere second-order polynomial.

3.2 LOOCV in classification

Let’s verify the error rate of our interaction-term model for the Titanic data set:

# function to generate assessment statistics for titanic model
holdout_results <- function(splits) {
  # Fit the model to the N-1
  mod <- glm(Survived ~ Age * Sex, data = analysis(splits),
             family = binomial)
  
  # `augment` will save the predictions with the holdout data set
  res <- augment(mod, newdata = assessment(splits)) %>% 
    as_tibble() %>%
    mutate(.prob = logit2prob(.fitted),
           .pred = round(.prob))
  
  # Return the assessment data set with the additional columns
  res
}

titanic_loocv <- loo_cv(titanic) %>%
  mutate(results = map(splits, holdout_results)) %>%
  unnest(results) %>%
  mutate(.pred = factor(.pred)) %>%
  group_by(id) %>%
  accuracy(truth = Survived, estimate = .pred)

1 - mean(titanic_loocv$.estimate, na.rm = TRUE)
## [1] 0.219888

In a classification problem, the LOOCV tells us the average error rate based on our predictions. So here, it tells us that the Age * Sex interaction model has a 22% error rate. This is similar to the validation set result (21.7%).

4 \(K\)-fold cross-validation

A less computationally-intensive approach to cross validation is \(K\)-fold cross-validation. Rather than dividing the data into \(N\) groups, one divides the observations into \(K\) groups, or folds, of approximately equal size. The first fold is treated as the validation set, and the model is estimated on the remaining \(K-1\) folds. This process is repeated \(K\) times, with each fold serving as the validation set precisely once. The \(K\)-fold CV estimate is calculated by averaging the MSE values for each fold:

\[CV_{(K)} = \frac{1}{K} \sum_{i = 1}^{K}{MSE_i}\]

As you may have noticed, LOOCV is a special case of \(K\)-fold cross-validation where \(K = N\). More typically researchers will use \(K=5\) or \(K=10\) depending on the size of the data set and the complexity of the statistical model.
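The fold logic can be sketched in base R as follows. As before, mtcars and mpg ~ hp are stand-ins, \(K = 5\) is an arbitrary choice, and the names fold, mse_fold and cv_k are hypothetical.

```r
# Sketch of K-fold cross-validation in base R.
# Assumptions: mtcars and mpg ~ hp as stand-ins; K = 5 chosen arbitrarily.
set.seed(1234)

k <- 5
n <- nrow(mtcars)

# Randomly assign each row to one of k folds of approximately equal size
fold <- sample(rep(seq_len(k), length.out = n))

# For each fold: train on the other k - 1 folds, then score the held-out fold
mse_fold <- sapply(seq_len(k), function(j) {
  train <- mtcars[fold != j, ]
  valid <- mtcars[fold == j, ]
  fit <- lm(mpg ~ hp, data = train)
  mean((valid$mpg - predict(fit, newdata = valid))^2)
})

# CV_(K): the average of the K fold-level MSEs
cv_k <- mean(mse_fold)
cv_k
```

Setting k equal to n puts every observation in its own fold, which reproduces the leave-one-out procedure.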

4.1 \(K\)-fold CV in linear regression

Let’s go back to the Auto data set. Instead of LOOCV, let’s use 10-fold CV to compare the different polynomial models.

# modified function to estimate model with varying highest order polynomial
holdout_results <- function(splits, i) {
  # Fit the model to the training set
  mod <- glm(mpg ~ poly(horsepower, i), data = analysis(splits))
  
  # Save the heldout observations
  holdout <- assessment(splits)
  
  # `augment` will save the predictions with the holdout data set
  res <- augment(mod, newdata = holdout) %>%
    # calculate the metric
    mse(truth = mpg, estimate = .fitted)
  
  # Return the one-row tibble holding the held-out MSE
  res
}

# function to return MSE for a specific higher-order polynomial term
poly_mse <- function(i, vfold_data){
  vfold_mod <- vfold_data %>%
    mutate(results = map(splits, holdout_results, i)) %>%
    unnest(results) %>%
    spread(.metric, .estimate)
  
  mean(vfold_mod$mse)
}

# split Auto into 10 folds
auto_cv10 <- vfold_cv(data = Auto, v = 10)

cv_mse <- tibble(terms = seq(from = 1, to = 5),
                     mse_vfold = map_dbl(terms, poly_mse, auto_cv10))
cv_mse
## # A tibble: 5 x 2
##   terms mse_vfold
##   <int>     <dbl>
## 1     1      24.2
## 2     2      19.3
## 3     3      19.4
## 4     4      19.6
## 5     5      19.3

How do these results compare to the LOOCV values?

auto_loocv <- loo_cv(Auto)

tibble(terms = seq(from = 1, to = 5),
       `10-fold` = map_dbl(terms, poly_mse, auto_cv10),
       LOOCV = map_dbl(terms, poly_mse, auto_loocv)
) %>%
  gather(method, MSE, -terms) %>%
  ggplot(aes(terms, MSE, color = method)) +
  geom_line() +
  labs(title = "MSE estimates",
       x = "Degree of Polynomial",
       y = "Mean Squared Error",
       color = "CV Method")

Pretty much the same results.

4.2 Computational speed of LOOCV vs. \(K\)-fold CV

4.2.1 LOOCV

library(profvis)

profvis({
  tibble(terms = seq(from = 1, to = 5),
             mse_loocv = map_dbl(terms, poly_mse, auto_loocv))
})

4.2.2 10-fold CV

profvis({
  tibble(terms = seq(from = 1, to = 5),
             mse_vfold = map_dbl(terms, poly_mse, auto_cv10))
})

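On my machine, 10-fold CV was about 40 times faster than LOOCV. This is roughly the ratio of the number of models each approach fits per polynomial degree (392 versus 10). Again, estimating \(K=10\) models is going to be much easier than estimating \(K=392\) models.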

4.3 \(K\)-fold CV in logistic regression

You’ve gotten the idea by now, but let’s do it one more time on our Titanic interaction model.

# function to generate assessment statistics for titanic model
holdout_results <- function(splits) {
  # Fit the model to the training set
  mod <- glm(Survived ~ Age * Sex, data = analysis(splits),
             family = binomial)
  
  # `augment` will save the predictions with the holdout data set
  res <- augment(mod, newdata = assessment(splits)) %>% 
    as_tibble() %>%
    mutate(.prob = logit2prob(.fitted),
           .pred = round(.prob))

  # Return the assessment data set with the additional columns
  res
}

titanic_cv10 <- vfold_cv(data = titanic, v = 10) %>%
  mutate(results = map(splits, holdout_results)) %>%
  unnest(results) %>%
  mutate(.pred = factor(.pred)) %>%
  group_by(id) %>%
  accuracy(truth = Survived, estimate = .pred)

1 - mean(titanic_cv10$.estimate, na.rm = TRUE)
## [1] 0.2200643

Not a large difference from the LOOCV approach, but it takes much less time to compute.

5 Appropriate value for \(K\)

Ignoring the computational efficiency concerns, why not always estimate cross-validation with \(K=N\)? Or more generally, what is the optimal value for \(K\)? It depends. Well that is not very helpful.

With more explanation, it depends on how we wish to handle the bias-variance tradeoff. LOOCV is a low-bias, high-variance method. That is, it provides approximately unbiased estimates of the test error since each training set contains \(N-1\) observations, almost as many as the full data set. \(K\)-fold CV for \(K=5\) or \(10\) leads to an intermediate amount of bias, since each training set contains \(\frac{(K-1)N}{K}\) observations. This is fewer than LOOCV, but more than a standard validation set approach with just a single split into training and validation sets. If all we care about is bias, we should prefer LOOCV.
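For a sense of scale, the training set sizes implied by each approach can be computed for the \(N = 392\) observations in Auto. The vector name train_size is hypothetical, and the fractional values reflect the fact that 392 does not divide evenly into 5 or 10 folds.

```r
# Training set size under each approach for the Auto data
n <- 392  # number of observations in Auto

train_size <- c(
  validation_50_50 = n / 2,              # 196
  cv_5_fold        = (5 - 1) * n / 5,    # 313.6
  cv_10_fold       = (10 - 1) * n / 10,  # 352.8
  loocv            = n - 1               # 391
)
train_size
```

The closer the training set is to the full data set, the less the estimated test error is inflated by fitting on too few observations.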

However, recall the contributors to a model’s error:

\[\text{Error} = \text{Irreducible Error} + \text{Bias}^2 + \text{Variance}\]

We also should be concerned with the variance of the model. LOOCV has a higher variance than \(K\)-fold with \(K < N\). When we perform LOOCV, we are averaging the outputs of \(N\) fitted models which are trained on nearly identical sets of observations. The results will be highly correlated with one another. In contrast, \(K\)-fold CV with \(K < N\) averages the output of \(K\) fitted models that are less correlated with one another, since their training sets overlap less. Since the mean of many highly correlated quantities has higher variance than the mean of quantities with less correlation, the test error estimate from LOOCV has higher variance than the test error estimate from \(K\)-fold CV.
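The claim about correlated quantities can be made concrete under a simplifying assumption that is not part of the original argument: suppose the \(n\) quantities being averaged, \(X_1, \dots, X_n\), each have variance \(\sigma^2\) and every pair has the same correlation \(\rho\). Then

\[\text{Var}\left(\frac{1}{n} \sum_{i = 1}^{n}{X_i}\right) = \frac{\sigma^2}{n} + \frac{n - 1}{n} \rho \sigma^2 = \rho \sigma^2 + \frac{(1 - \rho) \sigma^2}{n}\]

As \(n\) grows, the second term shrinks toward zero but the first does not. When \(\rho\) is close to 1, as with the heavily overlapping training sets in LOOCV, averaging removes very little of the variance. Real cross-validation errors will not satisfy the equal-correlation assumption exactly, so this is a stylized illustration rather than a precise result.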

Given these considerations, a typical approach uses \(K=5\) or \(K=10\). Empirical research (see Breiman and Spector (1992), Kohavi and others (1995)) shows that cross-validation with these numbers of folds suffers neither excessively high bias nor excessively high variance.

6 Variations on cross-validation

To ensure each set is approximately similar to one another in every important aspect, we use random sampling without replacement to partition the data set. Alternative approaches include:

7 Session Info

devtools::session_info()
## ─ Session info ──────────────────────────────────────────────────────────
##  setting  value                       
##  version  R version 3.5.2 (2018-12-20)
##  os       macOS Mojave 10.14.2        
##  system   x86_64, darwin15.6.0        
##  ui       X11                         
##  language (EN)                        
##  collate  en_US.UTF-8                 
##  ctype    en_US.UTF-8                 
##  tz       America/Chicago             
##  date     2019-01-25                  
## 
## ─ Packages ──────────────────────────────────────────────────────────────
##  package       * version    date       lib
##  assertthat      0.2.0      2017-04-11 [2]
##  backports       1.1.3      2018-12-14 [2]
##  base64enc       0.1-3      2015-07-28 [2]
##  bayesplot       1.6.0      2018-08-02 [2]
##  bindr           0.1.1      2018-03-13 [2]
##  bindrcpp      * 0.2.2      2018-03-29 [1]
##  blogdown        0.9.4      2018-11-26 [1]
##  bookdown        0.9        2018-12-21 [1]
##  broom         * 0.5.1      2018-12-05 [2]
##  callr           3.1.1      2018-12-21 [2]
##  cellranger      1.1.0      2016-07-27 [2]
##  class           7.3-15     2019-01-01 [2]
##  cli             1.0.1      2018-09-25 [1]
##  codetools       0.2-16     2018-12-24 [2]
##  colorspace      1.3-2      2016-12-14 [2]
##  colourpicker    1.0        2017-09-27 [2]
##  crayon          1.3.4      2017-09-16 [2]
##  crosstalk       1.0.0      2016-12-21 [2]
##  desc            1.2.0      2018-05-01 [2]
##  devtools        2.0.1      2018-10-26 [1]
##  dials         * 0.0.2      2018-12-09 [1]
##  digest          0.6.18     2018-10-10 [1]
##  dplyr         * 0.7.8      2018-11-10 [1]
##  DT              0.5        2018-11-05 [2]
##  dygraphs        1.1.1.6    2018-07-11 [2]
##  evaluate        0.12       2018-10-09 [2]
##  forcats       * 0.3.0      2018-02-19 [2]
##  fs              1.2.6      2018-08-23 [1]
##  generics        0.0.2      2018-11-29 [1]
##  ggplot2       * 3.1.0      2018-10-25 [1]
##  ggridges        0.5.1      2018-09-27 [2]
##  glue            1.3.0      2018-07-17 [2]
##  gower           0.1.2      2017-02-23 [2]
##  gridExtra       2.3        2017-09-09 [2]
##  gtable          0.2.0      2016-02-26 [2]
##  gtools          3.8.1      2018-06-26 [2]
##  haven           2.0.0      2018-11-22 [2]
##  here          * 0.1        2017-05-28 [2]
##  hms             0.4.2      2018-03-10 [2]
##  htmltools       0.3.6      2017-04-28 [1]
##  htmlwidgets     1.3        2018-09-30 [2]
##  httpuv          1.4.5.1    2018-12-18 [2]
##  httr            1.4.0      2018-12-11 [2]
##  igraph          1.2.2      2018-07-27 [2]
##  infer         * 0.4.0      2018-11-15 [1]
##  inline          0.3.15     2018-05-18 [2]
##  ipred           0.9-8      2018-11-05 [1]
##  janeaustenr     0.1.5      2017-06-10 [2]
##  jsonlite        1.6        2018-12-07 [2]
##  knitr           1.21       2018-12-10 [2]
##  later           0.7.5      2018-09-18 [2]
##  lattice         0.20-38    2018-11-04 [2]
##  lava            1.6.4      2018-11-25 [2]
##  lazyeval        0.2.1      2017-10-29 [2]
##  lme4            1.1-19     2018-11-10 [2]
##  loo             2.0.0      2018-04-11 [2]
##  lubridate       1.7.4      2018-04-11 [2]
##  magrittr      * 1.5        2014-11-22 [2]
##  markdown        0.9        2018-12-07 [2]
##  MASS            7.3-51.1   2018-11-01 [2]
##  Matrix          1.2-15     2018-11-01 [2]
##  matrixStats     0.54.0     2018-07-23 [2]
##  memoise         1.1.0      2017-04-21 [2]
##  mime            0.6        2018-10-05 [1]
##  miniUI          0.1.1.1    2018-05-18 [2]
##  minqa           1.2.4      2014-10-09 [2]
##  modelr          0.1.2      2018-05-11 [2]
##  munsell         0.5.0      2018-06-12 [2]
##  nlme            3.1-137    2018-04-07 [2]
##  nloptr          1.2.1      2018-10-03 [2]
##  nnet            7.3-12     2016-02-02 [2]
##  parsnip       * 0.0.1      2018-11-12 [1]
##  pillar          1.3.1      2018-12-15 [2]
##  pkgbuild        1.0.2      2018-10-16 [1]
##  pkgconfig       2.0.2      2018-08-16 [2]
##  pkgload         1.0.2      2018-10-29 [1]
##  plyr            1.8.4      2016-06-08 [2]
##  prettyunits     1.0.2      2015-07-13 [2]
##  pROC            1.13.0     2018-09-24 [1]
##  processx        3.2.1      2018-12-05 [2]
##  prodlim         2018.04.18 2018-04-18 [2]
##  promises        1.0.1      2018-04-13 [2]
##  ps              1.3.0      2018-12-21 [2]
##  purrr         * 0.2.5      2018-05-29 [2]
##  R6              2.3.0      2018-10-04 [1]
##  rcfss         * 0.1.5      2019-01-24 [1]
##  Rcpp            1.0.0      2018-11-07 [1]
##  readr         * 1.3.1      2018-12-21 [2]
##  readxl          1.2.0      2018-12-19 [2]
##  recipes       * 0.1.4      2018-11-19 [1]
##  remotes         2.0.2      2018-10-30 [1]
##  reshape2        1.4.3      2017-12-11 [2]
##  rlang           0.3.0.1    2018-10-25 [1]
##  rmarkdown       1.11       2018-12-08 [2]
##  rpart           4.1-13     2018-02-23 [1]
##  rprojroot       1.3-2      2018-01-03 [2]
##  rsample       * 0.0.3      2018-11-20 [1]
##  rsconnect       0.8.12     2018-12-05 [2]
##  rstan           2.18.2     2018-11-07 [2]
##  rstanarm        2.18.2     2018-11-10 [2]
##  rstantools      1.5.1      2018-08-22 [2]
##  rstudioapi      0.8        2018-10-02 [1]
##  rvest           0.3.2      2016-06-17 [2]
##  scales        * 1.0.0      2018-08-09 [1]
##  sessioninfo     1.1.1      2018-11-05 [1]
##  shiny           1.2.0      2018-11-02 [2]
##  shinyjs         1.0        2018-01-08 [2]
##  shinystan       2.5.0      2018-05-01 [2]
##  shinythemes     1.1.2      2018-11-06 [2]
##  SnowballC       0.5.1      2014-08-09 [2]
##  StanHeaders     2.18.0-1   2018-12-13 [2]
##  stringi         1.2.4      2018-07-20 [2]
##  stringr       * 1.3.1      2018-05-10 [2]
##  survival        2.43-3     2018-11-26 [2]
##  testthat        2.0.1      2018-10-13 [2]
##  threejs         0.3.1      2017-08-13 [2]
##  tibble        * 2.0.0      2019-01-04 [2]
##  tidymodels    * 0.0.2      2018-11-27 [1]
##  tidyposterior   0.0.2      2018-11-15 [1]
##  tidypredict     0.2.1      2018-12-20 [1]
##  tidyr         * 0.8.2      2018-10-28 [2]
##  tidyselect      0.2.5      2018-10-11 [1]
##  tidytext        0.2.0      2018-10-17 [1]
##  tidyverse     * 1.2.1      2017-11-14 [2]
##  timeDate        3043.102   2018-02-21 [2]
##  tokenizers      0.2.1      2018-03-29 [2]
##  usethis         1.4.0      2018-08-14 [1]
##  withr           2.1.2      2018-03-15 [2]
##  xfun            0.4        2018-10-23 [1]
##  xml2            1.2.0      2018-01-24 [2]
##  xtable          1.8-3      2018-08-29 [2]
##  xts             0.11-2     2018-11-05 [2]
##  yaml            2.2.0      2018-07-25 [2]
##  yardstick     * 0.0.2      2018-11-05 [1]
##  zoo             1.8-4      2018-09-19 [2]
##  source                           
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  Github (rstudio/blogdown@b2e1ed4)
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.2)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.2)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.1)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.1)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.2)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.2)                   
##  CRAN (R 3.5.2)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.2)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.2)                   
##  CRAN (R 3.5.1)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.1)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  local                            
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.2)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.2)                   
##  CRAN (R 3.5.1)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
##  CRAN (R 3.5.0)                   
## 
## [1] /Users/soltoffbc/Library/R/3.5/library
## [2] /Library/Frameworks/R.framework/Versions/3.5/Resources/library

8 References

Bergmeir, Christoph, and José M Benítez. 2012. “On the Use of Cross-Validation for Time Series Predictor Evaluation.” Information Sciences 191. Elsevier: 192–213.

Breiman, Leo, and Philip Spector. 1992. “Submodel Selection and Evaluation in Regression: The X-Random Case.” International Statistical Review/Revue Internationale de Statistique. JSTOR, 291–319.

Friedman, Jerome, Trevor Hastie, and Robert Tibshirani. 2001. The Elements of Statistical Learning. Springer Series in Statistics. New York, NY: Springer. https://web.stanford.edu/~hastie/ElemStatLearn/.

James, Gareth, Daniela Witten, Trevor Hastie, and Robert Tibshirani. 2013. An Introduction to Statistical Learning. Vol. 112. Springer. http://www-bcf.usc.edu/~gareth/ISL/.

Kohavi, Ron. 1995. “A Study of Cross-Validation and Bootstrap for Accuracy Estimation and Model Selection.” In IJCAI, 14:1137–45. Montreal, Canada.


  1. For educational purposes, here we will omit the test set. In a real-world situation, we would first partition out a test set of data.

  2. The actual value you use is irrelevant. Just be sure to set it in the script; otherwise R will pick a different seed each time you start a new session, and your results will not be reproducible.

  3. The default family for glm() is gaussian(), or the Gaussian distribution. You probably know it by its other name, the Normal distribution.